# _*_ coding: utf-8 _*_
# @Date : 2023/3/14 18:47
# @Author : Paul
# @File : decision_tree.py
# @Description : 决策树算法
import matplotlib.pyplot as plt
from core.beans.param_train_result import ParamTrainResult

from core.utils.log_util import LogUtil
from core.beans.classifier_result import ClassifierResult
from core.utils.date_util import DateUtil
from sklearn.tree import DecisionTreeClassifier
from classifiers.classifier import Classifier
import sys
import json

from core.utils.string_utils import StringUtils


class TWDecisionTree(Classifier):
    """Decision-tree classification job built on sklearn's ``DecisionTreeClassifier``.

    The class supplies the model hooks (``initModel``, ``buildModel``,
    ``evalModel``) and a hyper-parameter sweep (``paramTrain``).

    Attributes such as ``self.param``, ``self.feature_cols``, ``self.labels``,
    ``self.image_path``, ``self.info``, ``self.describe``, ``self.start_time``
    and ``self.meta_data_source`` are read here but never assigned in this
    class; presumably the ``Classifier`` base class provides them -- confirm.
    """

    # Algorithm identifier written into every persisted result record.
    _ALGO_NAME = "decision_tree"

    def __init__(self,
                 app_name="clusters",
                 data_source_id=None,
                 table_name=None,
                 feature_cols=None,
                 class_col=None,
                 train_size=None,
                 param=None
                 ):
        """
        Initialise the job; every argument is forwarded unchanged to ``Classifier``.

        :param app_name: application name, also embedded in the output file name
        :param data_source_id: forwarded to the base class
        :param table_name: forwarded to the base class
        :param feature_cols: forwarded to the base class
        :param class_col: forwarded to the base class
        :param train_size: forwarded to the base class
        :param param: job parameter dict (``algoParam``, ``paramTrain`` and
            ``id`` are the keys this class reads from it)
        """
        super(TWDecisionTree, self).__init__(app_name=app_name,
                                             data_source_id=data_source_id,
                                             table_name=table_name,
                                             feature_cols=feature_cols,
                                             class_col=class_col,
                                             train_size=train_size,
                                             param=param)
        # Model evaluation is switched ON for this classifier.
        # NOTE(review): the original comment claimed evaluation was off by
        # default, which contradicted the assigned value; the flag is not read
        # in this class, presumably the base class consumes it -- confirm.
        self.IS_MODEL_EVAL = True
        # Base file name (no extension) of the rendered decision-tree diagram.
        self.tree_pred_image = "descion_tree_pred_" + app_name + "_" + DateUtil.getCurrentDateSimple()

    @staticmethod
    def _read_option(options, key, default=None, cast=None):
        """Return ``options[key]``, or ``default`` when it is absent, None or blank.

        :param options: dict of raw option values
        :param key: option name to look up
        :param default: value returned for a missing/blank option
        :param cast: optional callable applied to a non-blank value
        :return: the (optionally converted) value or ``default``
        """
        raw = options.get(key)
        # Short-circuit on None so StringUtils.isBlack only ever sees real values.
        if raw is None or StringUtils.isBlack(raw):
            return default
        return raw if cast is None else cast(raw)

    def initModel(self):
        """
        Build ``self.clf`` from ``self.param["algoParam"]``.

        Blank or missing options fall back to the values used by the original
        code: criterion="gini", splitter="best", min_samples_split=2,
        min_samples_leaf=1, and None for random_state / max_depth.
        """
        algo_param = self.param["algoParam"]
        read = self._read_option

        self.clf = DecisionTreeClassifier(
            criterion=read(algo_param, "criterion", default="gini"),
            random_state=read(algo_param, "randomState", cast=int),
            splitter=read(algo_param, "splitter", default="best"),
            max_depth=read(algo_param, "maxDepth", cast=int),
            min_samples_split=read(algo_param, "minSamplesSplit", default=2, cast=int),
            min_samples_leaf=read(algo_param, "minSamplesLeaf", default=1, cast=int))

    def buildModel(self, train_data):
        """
        Fit ``self.clf`` on the training split.

        :param train_data: indexable whose item 0 holds the features and item 1
            the targets
        """
        self.clf = self.clf.fit(train_data[0], train_data[1])

    def evalModel(self, train_data, test_data):
        """
        Score the fitted model, render the tree diagram and persist the result.

        :param train_data: unused here; kept for interface compatibility
        :param test_data: indexable whose item 0 holds the features and item 1
            the targets used for scoring
        """
        score_ = self.clf.score(test_data[0], test_data[1])
        var_importance = list(zip(self.feature_cols, self.clf.feature_importances_))

        # Imported lazily so graphviz is only required when a model is evaluated.
        import graphviz
        from sklearn import tree
        dot_data = tree.export_graphviz(self.clf,
                                        feature_names=self.feature_cols,
                                        class_names=self.labels,
                                        filled=True,
                                        rounded=True)
        graph = graphviz.Source(dot_data)
        # render(view=False) writes the same file as view() but does not try to
        # launch a desktop viewer, which has no place in a batch job.
        graph.render(filename=self.tree_pred_image, directory=self.image_path, view=False)

        end_time = DateUtil.getCurrentDate()
        # NOTE(review): the original named this "cost_second" while the helper
        # is called diffMin -- the unit is not visible from this file.
        elapsed = DateUtil.diffMin(self.start_time, end_time)

        # Persist the evaluation result (the original comment says: to MySQL).
        # NOTE(review): the status literal "sucess" is misspelled and differs
        # from the "success" used by paramTrain; left untouched because
        # consumers of the stored value may match on it -- confirm and align.
        algo_result = ClassifierResult(self.param["id"],
                                       self._ALGO_NAME,
                                       self.param,
                                       self.app_name,
                                       self.info,
                                       self.describe,
                                       self.image_path + self.tree_pred_image + ".pdf",
                                       var_importance,
                                       score_,
                                       "sucess",
                                       self.start_time,
                                       end_time,
                                       elapsed)
        LogUtil.saveClassifierResult(self.meta_data_source, algo_result)

    def _save_param_train_result(self, detail, status):
        """Persist one hyper-parameter training record.

        :param detail: image path on success, error text on failure
        :param status: status string stored with the record
        """
        end_time = DateUtil.getCurrentDate()
        elapsed = DateUtil.diffMin(self.start_time, end_time)
        record = ParamTrainResult(self.param["id"],
                                  self._ALGO_NAME,
                                  self.param,
                                  self.app_name,
                                  detail,
                                  status,
                                  self.start_time,
                                  end_time,
                                  elapsed)
        LogUtil.saveParamTrainResult(self.meta_data_source, record)

    def paramTrain(self):
        """
        Sweep one hyper-parameter and persist the resulting learning curve.

        Reads ``paramName``, ``paramStartValue``, ``paramEndValue`` and
        ``paramRangeValue`` from ``self.param["paramTrain"]``. Only
        ``max_depth`` is swept; every other outcome (invalid numbers, empty
        range, unsupported name) is stored as a "failed" record instead of
        crashing or ending silently.

        :return: None
        """
        sweep = self.param["paramTrain"]
        param_name = self._read_option(sweep, "paramName", cast=str)
        try:
            start = self._read_option(sweep, "paramStartValue", cast=int)
            stop = self._read_option(sweep, "paramEndValue", cast=int)
            step = self._read_option(sweep, "paramRangeValue", cast=int)
        except (TypeError, ValueError):
            # Non-integer input is reported through the validation branch below.
            start = stop = step = None

        # Data is loaded before validation, exactly as the original did.
        train_data, test_data = self.getModelData()
        Xtrain = train_data[0]
        Ytrain = train_data[1]
        Xtest = test_data[0]
        Ytest = test_data[1]

        if param_name is None or start is None or stop is None or step is None or start == 1:
            self._save_param_train_result("请确认参数必须为整数，且参数起始值不能为1", "failed")
            return

        if param_name != "max_depth":
            # Previously this case fell through without writing any record.
            self._save_param_train_result("暂不支持该超参数的训练：" + param_name, "failed")
            return

        # range() raises on a zero step, so test the step first.
        if step == 0 or not range(start, stop, step):
            self._save_param_train_result("请确认参数取值范围不为空，且步长不能为0", "failed")
            return

        depths = list(range(start, stop, step))
        scores = []
        for depth in depths:
            # self.clf ends up holding the last fitted model, as before.
            self.clf = DecisionTreeClassifier(max_depth=depth)
            self.clf = self.clf.fit(Xtrain, Ytrain)
            scores.append(self.clf.score(Xtest, Ytest))

        param_train_image = self.image_path + "descion_tree_param_train_" + DateUtil.getCurrentDateSimple() + ".png"
        fig, ax = plt.subplots(1, 1)
        try:
            ax.set_title("最大数据深度--准确率--超参数学习曲线")
            ax.plot(depths, scores)
            fig.savefig(param_train_image, dpi=300)
        finally:
            # Release the figure; pyplot keeps it alive until explicitly closed.
            plt.close(fig)

        self._save_param_train_result(param_train_image, "success")

if __name__ == '__main__':
    # The job is configured by one JSON document passed as the first argument.
    # Keys read below: appName, dataSourceId, tableName, featureCols, classCols,
    # trainDataRatio (optional) and paramTrain (optional, selects the
    # hyper-parameter sweep instead of a normal run).
    if len(sys.argv) < 2:
        sys.exit("usage: decision_tree.py '<json job parameters>'")
    param = json.loads(sys.argv[1])

    # Use the requested train/test ratio; fall back to 0.7, the value that was
    # previously hard-coded, when the caller omits it or sends a blank value.
    ratio = param.get("trainDataRatio")
    if ratio is None or str(ratio).strip() == "":
        train_size = 0.7
    else:
        train_size = float(ratio)

    # The class column(s) are always handed over as a list.
    class_cols = param["classCols"]
    class_cols_list = class_cols if isinstance(class_cols, list) else [class_cols]

    classifier = TWDecisionTree(app_name=param["appName"],
                                data_source_id=param["dataSourceId"],
                                table_name=param["tableName"],
                                feature_cols=param["featureCols"],
                                class_col=class_cols_list,
                                train_size=train_size,
                                param=param)
    if "paramTrain" in param:
        classifier.paramTrain()
    else:
        classifier.execute()

